library(dplyr)
library(tidyr)
library(raster)
library(Matrix)
library(sparseMatrixStats)
library(Hmisc)
library(seegSDM)

library(ggplot2)
# Identifier of the model run being summarised. It is pasted into every
# input path read in this script and into the rendered report's file name.
# NOTE(review): save this file before running and check this name is set
# correctly both before and after rendering -- presumably the Rmd rendered
# below also relies on (or redefines) `run_unique_name`; confirm in the Rmd.
run_unique_name <- "FS_data_1000"
report_html <- paste0("reports/report_", run_unique_name, ".html")
rmarkdown::render("code_ruarai/R_interactive/summarise_model_run.Rmd",
                  output_file = report_html)
# Build the prediction matrix for this run ----
# Each file in the run's predictions directory is read with readRDS() and
# becomes one row of `full_pred_matrix`.
# NOTE(review): assumes each file holds a numeric vector (or one-column
# matrix) of per-cell predictions; confirm against the code that writes them.
pred_dir <- file.path("output/update/predictions", run_unique_name)
pred_files <- list.files(pred_dir, full.names = TRUE)
if (length(pred_files) == 0) {
  # Without this, do.call(rbind, list()) yields NULL, which would be saved and
  # only fail later with an unhelpful error.
  stop("No prediction files found in ", pred_dir, call. = FALSE)
}

preds <- lapply(pred_files, function(path) {
  pred_values <- readRDS(path)
  # Missing cells become 0 so they are not stored in the sparse matrix.
  pred_values[is.na(pred_values)] <- 0
  # Coerce to a sparse column, then transpose so each file is a single row.
  Matrix::t(as(pred_values, "sparseMatrix"))
})

pred_matrix_file <- file.path(
  "output/update",
  paste0(run_unique_name, "_pred_matrix.Rds")
)
# Keep the stacked matrix in memory instead of immediately re-reading the file
# just written. The uncompressed .Rds acts as a cache that a later session can
# reload with readRDS(pred_matrix_file).
full_pred_matrix <- do.call(rbind, preds)
saveRDS(full_pred_matrix, file = pred_matrix_file, compress = FALSE)
# Map the prediction summaries and compare with the original risk map ----
blank_seasia <- raster("data/clean/raster/SEAsia_extent.grd")
risk_freya <- raster("data/clean/raster/SEAsia.tif")

# Green -> pale yellow -> purple ramp; pale yellow sits in the middle.
cols <- colorRampPalette(
  c("#55843b", "#a4be79", "#ffffbf", "#921d67", "#6c0043")
)(100)

# Per-cell (column-wise) mean and SD across all prediction rows.
pred_means <- colMeans2(full_pred_matrix)
pred_sd <- colSds(full_pred_matrix)

# NA predictions were replaced with 0 when the matrix was built, so a zero
# summary is shown as missing.
# NOTE(review): this also blanks cells with genuine all-zero predictions or an
# SD of exactly 0. Cells that were NA in only some files still average
# those zeros into their mean.
pred_means[pred_means == 0] <- NA
pred_sd[pred_sd == 0] <- NA

mean_map <- setValues(blank_seasia, pred_means)
plot(mean_map, col = cols, zlim = c(0, 1))
title("Predicted risk means", run_unique_name)

plot(setValues(blank_seasia, pred_sd))
title("Predicted risk standard deviation", run_unique_name)

# Subtract raster from raster, so raster compares the two layers' geometry
# instead of recycling a bare vector over the cells of `risk_freya`.
diff_means <- mean_map - risk_freya

# The palette is diverging, so use limits symmetric about zero: the pale
# midpoint colour then means "no change".
diff_lim <- max(abs(range(values(diff_means), na.rm = TRUE)))
plot(diff_means, col = cols, zlim = c(-diff_lim, diff_lim))
title("Change in risk", paste0("FS's SEAsia.tif vs ", run_unique_name))

hist(diff_means, xlab = "Change in risk from original FS risk map")
title("", run_unique_name)

# Cross-validation statistics ----
# All CSVs in the run's model_stats directory are row-bound together.
# The first column is dropped -- presumably the row index written by
# write.csv(); confirm against the writer.
stats_files <- list.files(
  file.path("output/update/model_stats", run_unique_name),
  full.names = TRUE
)
if (length(stats_files) == 0) {
  # Otherwise NULL[, -1] fails with "incorrect number of dimensions".
  stop("No model_stats files found for run ", run_unique_name, call. = FALSE)
}
model_stats <- do.call(rbind, lapply(stats_files, read.csv))
# drop = FALSE keeps a data.frame even if only one column remains.
model_stats <- model_stats[, -1, drop = FALSE]

print(paste0("CV AUC mean: ", round(mean(model_stats$auc), 2)))
## Example output from a previous run: [1] "CV AUC mean: 0.78"
print(paste0("CV RMSE mean: ", round(mean(model_stats$rmse), 2)))
## Example output from a previous run: [1] "CV RMSE mean: 0.45"
# Relative influence of covariates ----
# All CSVs in the run's model_rel_inf directory are row-bound together.
# The first column is dropped -- presumably the write.csv() row index; confirm.
rel_inf_files <- list.files(
  file.path("output/update/model_rel_inf", run_unique_name),
  full.names = TRUE
)
if (length(rel_inf_files) == 0) {
  stop("No model_rel_inf files found for run ", run_unique_name, call. = FALSE)
}
model_rel_inf <- do.call(rbind, lapply(rel_inf_files, read.csv))
model_rel_inf <- model_rel_inf[, -1, drop = FALSE]
# Long format: `name` is the original column (covariate), `value` its
# relative influence in one model.
model_rl_long <- model_rel_inf %>% pivot_longer(cols = everything())

# Order covariates by mean relative influence, largest first.
# `.groups = "drop"` returns an ungrouped result and silences the
# summarise() ungrouping message.
means <- model_rl_long %>%
  group_by(name) %>%
  summarise(mean_rl = mean(value), .groups = "drop") %>%
  arrange(desc(mean_rl))
model_rl_long$name <- factor(model_rl_long$name, levels = means$name)

ggplot(model_rl_long) +
  geom_jitter(aes(x = name, y = value), size = 0.2, width = 0.2, alpha = 0.5) +
  theme(axis.text.x = element_text(angle = -45, hjust = 0)) +
  xlab("Covariate") + ylab("Relative Influence")

# log10 of a zero influence is -Inf, which ggplot would drop with a warning.
# Drop non-positive values explicitly. scale_x_discrete(drop = FALSE) keeps
# every covariate on the axis, even one whose values are all zero.
ggplot(dplyr::filter(model_rl_long, value > 0)) +
  geom_jitter(aes(x = name, y = value), size = 0.2, width = 0.2, alpha = 0.5) +
  scale_x_discrete(drop = FALSE) +
  scale_y_continuous(trans = "log10") +
  theme(axis.text.x = element_text(angle = -45, hjust = 0)) +
  xlab("Covariate") + ylab("Relative Influence (log10 adjusted)")

# Effect plots ----
# Data read from the run's _effect_plot.Rds file. It must provide the columns
# `x`, `value` and `var_name` used in the mappings below.
effect_plot_file <- paste0("output/update/", run_unique_name, "_effect_plot.Rds")
effect_plot <- readRDS(effect_plot_file)

# One facet per `var_name`, drawn with tiny, highly transparent points.
ggplot(effect_plot, aes(x = x, y = value)) +
  geom_point(alpha = 0.05, size = 0.1) +
  facet_wrap(~var_name)